# -*- coding: utf-8 -*-

"""
Comment:

"""

import numpy as np
from math import e
from scipy.stats import uniform, gumbel_r, gumbel_l
from gurobipy import *
from gurobipy import Model
from gurobipy import quicksum
import copy
from scipy.stats import binom
import time
import matlab as ml

def mnl(os2,prices,segQ,segments,segScale,segNP):
    """Return multinomial-logit purchase probabilities per offer set and segment.

    For segment ``l`` product ``j`` has attraction
    ``exp((segQ[j, l] - price_j) / segScale[l])`` and the no-purchase option
    has weight ``exp(segNP[0, l])``. In offer set ``i`` each product's
    attraction is multiplied by ``os2[i, j]``, and all weights are
    normalised by their sum.

    Parameters
    ----------
    os2 : ndarray, shape (nSets, nProducts)
        Offer-set matrix; row ``i`` scales each product's attraction
        (presumably 0/1 membership - confirm against callers).
    prices : ndarray, shape (nProducts, 1) or (nProducts,)
        Product prices; flattened before use.
    segQ : ndarray, indexed as ``segQ[j, l]``
        Mean quality index of product ``j`` for segment ``l``.
    segments : ndarray of int
        Segment indices; each is used directly as an index into the last
        axis of the result (assumed to be ``0 .. len(segments) - 1``).
    segScale : ndarray, indexed as ``segScale[l]``
        Logit scale parameter of each segment.
    segNP : ndarray, indexed as ``segNP[0, l]``
        No-purchase utility of each segment.

    Returns
    -------
    ndarray, shape (nSets, nProducts + 1, len(segments))
        ``[i, j, l]`` is the probability that a segment-``l`` customer buys
        product ``j`` when offered set ``i``; ``[i, -1, l]`` is the
        no-purchase probability.
    """
    os2 = np.asarray(os2)
    price_vec = np.ravel(prices)
    nSets, nProducts = os2.shape
    prob_MNL = np.zeros([nSets, nProducts + 1, segments.shape[0]])
    for l in segments:
        # Attraction of every product for this segment; it does not depend on the set.
        attraction = np.exp((segQ[:, l] - price_vec) / segScale[l])
        nopurch_weight = np.exp(segNP[0, l])
        # Scale attractions by the offer-set rows (a 0 entry removes the product).
        weighted = os2 * attraction
        # One shared denominator per offer set, computed once.
        denom = nopurch_weight + weighted.sum(axis=1)
        prob_MNL[:, :nProducts, l] = weighted / denom[:, np.newaxis]
        prob_MNL[:, nProducts, l] = nopurch_weight / denom
    return prob_MNL

def getRandomstreams(products,segments,segLambdas,segQualityindices,segScale,timeperiods,runs,nEvalWSK,segNP,seed):
    """Pre-draw customer arrivals and their random utilities for the simulation.

    For every (period, run, evaluation) one uniform draw decides whether a
    customer arrives and from which segment: no customer arrives when the draw
    is >= ``np.sum(segLambdas)``. Otherwise the segment is the first index
    whose cumulative ``segLambdas`` reaches the draw. The same ``segLambdas``
    are used for every period. The original author questioned whether this
    holds when lambdas vary over time.

    Returns
    -------
    randomSegments : int32 ndarray, shape (nPeriods, nRuns, nEvalWSK)
        Arriving segment index, or -1000 when nobody arrives.
    randomQualityindices : ndarray, shape (nPeriods, nProducts, nRuns, nEvalWSK)
        Gumbel noise plus ``segQualityindices[:, segment]``; -inf when nobody
        arrives.
    randomNopurch : ndarray, shape (nPeriods, nRuns, nEvalWSK)
        No-purchase noise plus ``segNP[0, segment] * segScale[segment]``;
        left at 0 when nobody arrives.

    NOTE(review): the uniform draw and every per-segment ``gumbel_r.rvs`` call
    receive the same ``seed``. If ``seed`` is an int, each call replays the
    same underlying stream. The noise is then identical across segments (up to
    scale) and correlated with the arrival draws. Confirm this is intended.
    """
    nT = timeperiods.size
    nP = products.size
    nS = segments.size
    nR = runs.size

    randomSegments = np.zeros((nT, nR, nEvalWSK), dtype=np.int32)
    randomQualityindices = np.zeros((nT, nP, nR, nEvalWSK))
    randomNopurch = np.zeros((nT, nR, nEvalWSK))

    # Arrival / segment selection draws.
    arrivalDraw = uniform.rvs(loc=0, scale=1, size=[nT, nR, nEvalWSK], random_state=seed)

    # Gumbel noise per product, with the no-purchase option in the last slot of
    # axis 1. It is drawn segment by segment because the scale is segment specific.
    noise = np.zeros((nT, nP + 1, nR, nS, nEvalWSK))
    for seg in segments:
        noise[:, :, :, seg, :] = gumbel_r.rvs(loc=0, scale=segScale[seg], size=[nT, nP + 1, nR, nEvalWSK], random_state=seed)

    # Loop invariants: the same arrival probabilities apply in every period.
    totalArrivalProb = np.sum(segLambdas)
    cumulativeProbs = np.cumsum(segLambdas)

    for ev in np.arange(nEvalWSK):
        for run in runs:
            for t in timeperiods:
                draw = arrivalDraw[t, run, ev]
                if draw >= totalArrivalProb:
                    # Nobody arrives in this period.
                    randomSegments[t, run, ev] = -1000
                    randomQualityindices[t, :, run, ev] = -np.inf
                    continue
                seg = np.argmin(cumulativeProbs < draw)
                randomSegments[t, run, ev] = seg
                randomQualityindices[t, :, run, ev] = noise[t, :nP, run, seg, ev] + segQualityindices[:, seg]
                randomNopurch[t, run, ev] = noise[t, nP, run, seg, ev] + segNP[0, seg] * segScale[seg]
    return randomSegments, randomQualityindices, randomNopurch

def getRandomstreams_Buyer_j(nE,inst,runn,nTimeperiods,nRuns,nEvalWSK,c,segScale2,nProducts,j,offerset,l,prices,segQ,x):
    """Rejection-sample the random utilities of a customer who chooses product ``j``.

    Batches of ``nQSamples`` Gumbel vectors (one entry per product plus a
    trailing no-purchase entry) are drawn with ``random_state=1 + x``. A draw
    is accepted when product ``j`` has the largest ``noise - price`` among the
    products with a non-zero ``offerset`` entry. Accepted draws are collected
    in draw order until ``samplesize`` are found. ``x`` grows by one per batch
    so the caller can continue with fresh seeds.

    ``nE``, ``inst``, ``runn``, ``nTimeperiods``, ``nRuns``, ``nEvalWSK`` and
    ``c`` are not used; they are kept for interface compatibility.

    Returns
    -------
    sampleQualityindices : ndarray, shape (samplesize, nProducts)
        Accepted product noise plus ``segQ[:, l]``.
    sampleNopurch : ndarray, shape (samplesize,)
        Accepted no-purchase noise (no segment shift added).
    x : int
        Seed offset after the last batch drawn.

    Raises
    ------
    ValueError
        If ``j`` is outside ``0 .. nProducts - 1`` or not offered; such a
        customer can never be accepted (previously this looped forever).

    NOTE(review): the acceptance test uses raw noise minus price. ``segQ[:, l]``
    is only added to the stored sample, and the no-purchase entry is not
    compared. Confirm this is intended (e.g. that ``prices`` are already
    adjusted).
    """
    samplesize = 1
    nQSamples = samplesize*100000
    offered = np.ravel(offerset) != 0
    if not (0 <= j < nProducts) or not offered[j]:
        raise ValueError("product j=%r is not an offered product; no buyer can choose it" % (j,))
    price_vec = np.squeeze(prices)
    sampleQualityindices = np.zeros((samplesize,nProducts))
    sampleNopurch = np.zeros((samplesize))
    t = 0
    while t < samplesize:
        Q = gumbel_r.rvs(loc=0,scale=segScale2,size=[nQSamples,nProducts+1],random_state=1+x)
        # Utilities of all draws at once; products outside the offer set can never win.
        utilities = Q[:, :nProducts] - price_vec
        utilities[:, ~offered] = -np.inf
        # First draws (in draw order) whose best product is j, limited to the samples still needed.
        hits = np.flatnonzero(np.argmax(utilities, axis=1) == j)[:samplesize - t]
        nHits = hits.size
        sampleQualityindices[t:t + nHits, :] = Q[hits, :nProducts] + segQ[:, l]
        sampleNopurch[t:t + nHits] = Q[hits, nProducts]
        t += nHits
        x += 1
    return sampleQualityindices,sampleNopurch,x

def cdlp(offerSets,RS,capcon,cap,probMNL,nTimeperiods,products,eng):
    """Solve the CDLP in MATLAB by calling ``eng.cdlp2``.

    ``offerSets`` and ``products`` are passed as MATLAB int32 arrays. All
    other numeric inputs are passed as MATLAB doubles, with ``nTimeperiods``
    wrapped in a one-element list.

    Returns
    -------
    ts : ndarray
        First output of ``cdlp2`` (``timeslots``) converted to numpy.
    ofv : object
        Second output of ``cdlp2``, returned unchanged.
    """
    # Drop the last entry along axis 1 (presumably the no-purchase column, as
    # laid out by ``mnl``); squeeze removes singleton axes such as a single segment.
    purchaseProbs = np.copy(np.squeeze(probMNL[:, :-1, :]))
    matlabArgs = (
        ml.int32(offerSets.tolist()),
        ml.double(RS.tolist()),
        ml.double(capcon.tolist()),
        ml.double(cap.tolist()),
        ml.double(purchaseProbs.tolist()),
        ml.double([nTimeperiods]),
        ml.int32(products.tolist()),
    )
    timeslots, ofv = eng.cdlp2(*matlabArgs, nargout=2)
    return np.array(timeslots), ofv

def getRandomOSwithCDLP(nTimeperiods,nRuns,Ts,seed):
    """Sample an offer-set index for every (period, run) pair.

    Offer set ``k`` is chosen with probability ``Ts[k] / nTimeperiods``. The
    chosen set is the first index whose cumulative probability is not smaller
    than a uniform draw.

    Returns
    -------
    ndarray of float, shape (nTimeperiods, nRuns)
        Selected offer-set indices, stored as floats.

    NOTE(review): when a draw exceeds the last cumulative probability (e.g.
    ``Ts`` sums to less than ``nTimeperiods``), ``np.argmin`` of an all-True
    mask returns 0, so offer set 0 is selected. Confirm this fallback is
    intended.
    """
    draws = uniform.rvs(loc=0, scale=1, size=[nTimeperiods, nRuns], random_state=seed)
    thresholds = np.cumsum(Ts / nTimeperiods)
    OSNo = np.zeros(draws.shape)
    for (t, run), u in np.ndenumerate(draws):
        # First set whose cumulative probability reaches the draw.
        OSNo[t, run] = np.argmin(thresholds < u)
    return OSNo

def efficientSets_largest_marginal_revenue_algorithm(Qt,RS,os):
    """Select efficient offer sets with the largest-marginal-revenue rule.

    Starting from set 0, repeatedly move to the not-yet-selected set whose
    ``Qt`` and ``RS`` are both at least those of the current set and whose
    ratio ``(RS[s] - RS[cur]) / (Qt[s] - Qt[cur])`` is largest. Stop when no
    such set remains.

    Parameters
    ----------
    Qt : array-like
        One value per offer set; squeezed, assumed to become 1-D.
    RS : array-like
        One revenue value per offer set; squeezed, assumed to become 1-D.
    os : object
        Unused; kept for interface compatibility.

    Returns
    -------
    list
        Selected set indices in selection order, starting with 0.
    """
    Qt = np.squeeze(Qt)
    RS = np.squeeze(RS)
    current = 0
    effSets = [0]
    while True:
        candidates = np.flatnonzero((Qt >= Qt[current]) & (RS >= RS[current]))
        candidates = candidates[~np.isin(candidates, effSets)]
        if candidates.size == 0:
            return effSets
        # Anchor the ratio on the *set index* of the current efficient set
        # (previously a position inside the candidate array was used).
        # Equal Qt gives +inf (strictly more revenue) or nan (duplicate); np.argmax
        # returns the first nan, as before, so only the warnings are silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = (RS[candidates] - RS[current]) / (Qt[candidates] - Qt[current])
        current = candidates[np.argmax(ratio)]
        effSets.append(current)

def _two_product_upsell(uniqSeg, counts2, nProducts, prices):
    """Build upsell structures for the two-product case (every segment may upsell to product 1)."""
    products = np.arange(nProducts)
    nNewSegs = uniqSeg.shape[0]
    p_comp = np.zeros((nNewSegs, nProducts))
    uplist_M = np.zeros((nProducts, nProducts, nNewSegs))
    y_M = np.zeros((nProducts, nProducts, nNewSegs))
    uplist = []
    klist = []
    for i, nSeg in enumerate(uniqSeg):
        base = nSeg[1].astype(int)
        # Column 0 receives the price of the segment's base product, whatever
        # the base is. NOTE(review): confirm this is intended when base == 1.
        p_comp[i, 0] = prices[base][0]
        u = []
        for pro in products[1:]:
            p_comp[i, pro] = prices[pro][0]
            u.append(pro)
            uplist_M[base, pro, i] = 1
            y_M[base, pro, i] = counts2[i]
        uplist.append(u)
        klist.append(list(range(len(u))))
    return uplist, klist, p_comp, uplist_M, y_M


def _single_step_cascading(uniqSeg, counts2, nProducts, nRes, prices, capcon, rem_cap, skip_bases=()):
    """Build single-step cascading upsell structures (one same-parity, higher upsell target per segment).

    Segments whose base product is in ``skip_bases`` keep an all-zero price
    row and get no upsell target.
    """
    products = np.arange(nProducts)
    nNewSegs = uniqSeg.shape[0]
    p_comp = np.zeros((nNewSegs, nProducts))
    uplist_M = np.zeros((nProducts, nProducts, nNewSegs))
    y_M = np.zeros((nProducts, nProducts, nNewSegs))
    uplist = []
    klist = []
    for i, nSeg in enumerate(uniqSeg):
        base = nSeg[1].astype(int)
        u = []
        kli = []
        if base not in skip_bases:
            for pro in products:
                if pro == base:
                    p_comp[i, pro] = prices[base][0]
                # Same parity as the base, a higher product, and the first upsell option found.
                elif pro % 2 == nSeg[1] % 2 and pro > nSeg[1] and not u:
                    p_comp[i, pro] = prices[pro][0]
                    kli.append(len(u))
                    u.append(pro)
                    uplist_M[base, pro, i] = 1
                    y_M[base, pro, i] = counts2[i]
                # Exclusion when the product is in the highest class: the price stays 0.
                # NOTE(review): np.where returns column indices < capcon.shape[1], so
                # "== nRes" looks like it can never hold if capcon has nRes columns
                # (perhaps nRes - 1 was meant). With several 1-entries in a row the
                # truth value is ambiguous. Kept unchanged; please confirm.
                elif (capcon[pro, :] * rem_cap[0])[nRes - 1] == 0 and np.where(capcon[pro, :] == 1)[0] == nRes:
                    pass
                else:
                    p_comp[i, pro] = np.inf
        uplist.append(u)
        klist.append(kli)
    return uplist, klist, p_comp, uplist_M, y_M


def calc_upsellpossibilities(uphierar,uniqSeg,counts2,nProducts,nRes,offerSets,prices,capcon,rem_cap):
    """Determine per-segment upsell targets and segment-specific price vectors.

    ``uniqSeg[:, 1]`` is used (cast to int) as each segment's base product
    index. ``offerSets`` is unused.

    Returns
    -------
    uplist : list of lists
        Upsell target products per segment.
    klist : list of lists
        Position indices (0, 1, ...) matching ``uplist``.
    p_comp : ndarray, shape (nNewSegs, nProducts)
        Per-segment prices: base/upsell prices, ``inf`` for blocked products,
        and 0 where nothing was set.
    uplist_M : ndarray, shape (nProducts, nProducts, nNewSegs)
        1 at ``[base, target, segment]`` for each upsell.
    y_M : ndarray, shape (nProducts, nProducts, nNewSegs)
        ``counts2[segment]`` at ``[base, target, segment]`` for each upsell.

    Raises
    ------
    ValueError
        For a product count / hierarchy combination that is not implemented
        (previously this surfaced as an UnboundLocalError).
    """
    if nProducts == 2:
        return _two_product_upsell(uniqSeg, counts2, nProducts, prices)
    if nProducts == 12:
        # Segments based on products 4 or 5 get no upsell and an all-zero price row.
        return _single_step_cascading(uniqSeg, counts2, nProducts, nRes, prices, capcon, rem_cap, skip_bases=(4, 5))
    if uphierar == 'F2FnS2S':
        return _single_step_cascading(uniqSeg, counts2, nProducts, nRes, prices, capcon, rem_cap)
    raise ValueError(
        "calc_upsellpossibilities: unsupported combination nProducts=%r, uphierar=%r" % (nProducts, uphierar))
  